<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - model_selection_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*

    This is an example that shows how you can perform model selection with the
    dlib C++ Library.  

    It will create a simple dataset and show you how to use cross validation and
    global optimization to determine good parameters for the purpose of training
    an svm to classify the data.

    The data used in this example will be 2 dimensional data and will come from a
    distribution where points with a distance less than 10 from the origin are
    labeled +1 and all other points are labeled as -1.
        

    As an aside, you should probably read the <a href="svm_ex.cpp.html">svm_ex.cpp</a> and <a href="matrix_ex.cpp.html">matrix_ex.cpp</a> example
    programs before you read this one.
*/</font>


<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>svm.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>global_optimization.h<font color='#5555FF'>&gt;</font>

<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;


<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
    <font color='#009900'>// The svm functions use column vectors to contain a lot of the data on which they 
</font>    <font color='#009900'>// operate. So the first thing we do here is declare a convenient typedef.  
</font>
    <font color='#009900'>// This typedef declares a matrix with 2 rows and 1 column.  It will be the
</font>    <font color='#009900'>// object that contains each of our 2 dimensional samples.   
</font>    <font color='#0000FF'>typedef</font> matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font>, <font color='#979000'>2</font>, <font color='#979000'>1</font><font color='#5555FF'>&gt;</font> sample_type;



    <font color='#009900'>// Now we make objects to contain our samples and their respective labels.
</font>    std::vector<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font> samples;
    std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> labels;

    <font color='#009900'>// Now let's put some data into our samples and labels objects.  We do this
</font>    <font color='#009900'>// by looping over a bunch of points and labeling them according to their
</font>    <font color='#009900'>// distance from the origin.
</font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> r <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>20</font>; r <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>20</font>; r <font color='#5555FF'>+</font><font color='#5555FF'>=</font> <font color='#979000'>0.8</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font> c <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>20</font>; c <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>20</font>; c <font color='#5555FF'>+</font><font color='#5555FF'>=</font> <font color='#979000'>0.8</font><font face='Lucida Console'>)</font>
        <b>{</b>
            sample_type samp;
            <font color='#BB00BB'>samp</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> r;
            <font color='#BB00BB'>samp</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> c;
            samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>samp<font face='Lucida Console'>)</font>;

            <font color='#009900'>// if this point is at most 10 from the origin
</font>            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>sqrt</font><font face='Lucida Console'>(</font>r<font color='#5555FF'>*</font>r <font color='#5555FF'>+</font> c<font color='#5555FF'>*</font>c<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>10</font><font face='Lucida Console'>)</font>
                labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#5555FF'>+</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>else</font>
                labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Generated </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> points</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;


    <font color='#009900'>// Here we normalize all the samples by subtracting their mean and dividing by their
</font>    <font color='#009900'>// standard deviation.  This is generally a good idea since it often heads off
</font>    <font color='#009900'>// numerical stability problems and also prevents one large feature from smothering
</font>    <font color='#009900'>// others.  Doing this doesn't matter much in this example so I'm just doing this here
</font>    <font color='#009900'>// so you can see an easy way to accomplish this with the library.  
</font>    vector_normalizer<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font> normalizer;
    <font color='#009900'>// let the normalizer learn the mean and standard deviation of the samples
</font>    normalizer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>samples<font face='Lucida Console'>)</font>;
    <font color='#009900'>// now normalize each sample
</font>    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
        samples[i] <font color='#5555FF'>=</font> <font color='#BB00BB'>normalizer</font><font face='Lucida Console'>(</font>samples[i]<font face='Lucida Console'>)</font>; 


    <font color='#009900'>// Now that we have some data we want to train on it.  We are going to train a
</font>    <font color='#009900'>// binary SVM with the RBF kernel to classify the data.  However, there are
</font>    <font color='#009900'>// three parameters to the training.  These are the SVM C parameters for each
</font>    <font color='#009900'>// class and the RBF kernel's gamma parameter.  Our choice for these
</font>    <font color='#009900'>// parameters will influence how good the resulting decision function is.  To
</font>    <font color='#009900'>// test how good a particular choice of these parameters is we can use the
</font>    <font color='#009900'>// cross_validate_trainer() function to perform n-fold cross validation on our
</font>    <font color='#009900'>// training data.  However, there is a problem with the way we have sampled
</font>    <font color='#009900'>// our distribution above.  The problem is that there is a definite ordering
</font>    <font color='#009900'>// to the samples.  That is, the first half of the samples look like they are
</font>    <font color='#009900'>// from a different distribution than the second half.  This would screw up
</font>    <font color='#009900'>// the cross validation process, but we can fix it by randomizing the order of
</font>    <font color='#009900'>// the samples with the following function call.
</font>    <font color='#BB00BB'>randomize_samples</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;


    <font color='#009900'>// And now we get to the important bit.  Here we define a function,
</font>    <font color='#009900'>// cross_validation_score(), that will do the cross-validation we
</font>    <font color='#009900'>// mentioned and return a number indicating how good a particular setting
</font>    <font color='#009900'>// of gamma, c1, and c2 is.
</font>    <font color='#0000FF'>auto</font> cross_validation_score <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>]<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> gamma, <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> c1, <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> c2<font face='Lucida Console'>)</font> 
    <b>{</b>
        <font color='#009900'>// Make a RBF SVM trainer and tell it what the parameters are supposed to be.
</font>        <font color='#0000FF'>typedef</font> radial_basis_kernel<font color='#5555FF'>&lt;</font>sample_type<font color='#5555FF'>&gt;</font> kernel_type;
        svm_c_trainer<font color='#5555FF'>&lt;</font>kernel_type<font color='#5555FF'>&gt;</font> trainer;
        trainer.<font color='#BB00BB'>set_kernel</font><font face='Lucida Console'>(</font><font color='#BB00BB'>kernel_type</font><font face='Lucida Console'>(</font>gamma<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        trainer.<font color='#BB00BB'>set_c_class1</font><font face='Lucida Console'>(</font>c1<font face='Lucida Console'>)</font>;
        trainer.<font color='#BB00BB'>set_c_class2</font><font face='Lucida Console'>(</font>c2<font face='Lucida Console'>)</font>;

        <font color='#009900'>// Finally, perform 10-fold cross validation and then print and return the results.
</font>        matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font> result <font color='#5555FF'>=</font> <font color='#BB00BB'>cross_validate_trainer</font><font face='Lucida Console'>(</font>trainer, samples, labels, <font color='#979000'>10</font><font face='Lucida Console'>)</font>;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>gamma: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>setw</font><font face='Lucida Console'>(</font><font color='#979000'>11</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> gamma <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>  c1: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>setw</font><font face='Lucida Console'>(</font><font color='#979000'>11</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> c1 <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>  "<font color='#CC0000'>  c2: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#BB00BB'>setw</font><font face='Lucida Console'>(</font><font color='#979000'>11</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> c2 <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font>  "<font color='#CC0000'>  cross validation accuracy: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> result;

        <font color='#009900'>// Now return a number indicating how good the parameters are.  Bigger is
</font>        <font color='#009900'>// better in this example.  Here I'm returning the harmonic mean of the
</font>        <font color='#009900'>// accuracies of each class.  However, you could do something else.  For
</font>        <font color='#009900'>// example, you might care a lot more about correctly predicting the +1 class,
</font>        <font color='#009900'>// so you could penalize results that didn't obtain a high accuracy on that
</font>        <font color='#009900'>// class.  You might do this by using something like a weighted version of the
</font>        <font color='#009900'>// F1-score (see http://en.wikipedia.org/wiki/F1_score).     
</font>        <font color='#0000FF'>return</font> <font color='#979000'>2</font><font color='#5555FF'>*</font><font color='#BB00BB'>prod</font><font face='Lucida Console'>(</font>result<font face='Lucida Console'>)</font><font color='#5555FF'>/</font><font color='#BB00BB'>sum</font><font face='Lucida Console'>(</font>result<font face='Lucida Console'>)</font>;
    <b>}</b>;


    <font color='#009900'>// And finally, we call this global optimizer that will search for the best parameters.
</font>    <font color='#009900'>// It will call cross_validation_score() 50 times with different settings and return
</font>    <font color='#009900'>// the best parameter setting it finds.  find_max_global() uses a global optimization
</font>    <font color='#009900'>// method based on a combination of non-parametric global function modeling and
</font>    <font color='#009900'>// quadratic trust region modeling to efficiently find a global maximizer.  It usually
</font>    <font color='#009900'>// does a good job with a relatively small number of calls to cross_validation_score().
</font>    <font color='#009900'>// In this example, you should observe that it finds settings that give perfect binary
</font>    <font color='#009900'>// classification of the data.
</font>    <font color='#0000FF'>auto</font> result <font color='#5555FF'>=</font> <font color='#BB00BB'>find_max_global</font><font face='Lucida Console'>(</font>cross_validation_score, 
                                  <b>{</b><font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font>, <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font>, <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>5</font><b>}</b>,  <font color='#009900'>// lower bound constraints on gamma, c1, and c2, respectively
</font>                                  <b>{</b><font color='#979000'>100</font>,  <font color='#979000'>1e6</font>,  <font color='#979000'>1e6</font><b>}</b>,   <font color='#009900'>// upper bound constraints on gamma, c1, and c2, respectively
</font>                                  <font color='#BB00BB'>max_function_calls</font><font face='Lucida Console'>(</font><font color='#979000'>50</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#0000FF'><u>double</u></font> best_gamma <font color='#5555FF'>=</font> result.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'><u>double</u></font> best_c1    <font color='#5555FF'>=</font> result.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'><u>double</u></font> best_c2    <font color='#5555FF'>=</font> result.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>;

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> best cross-validation score: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> result.y <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> best gamma: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> best_gamma <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>   best c1: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> best_c1 <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>    best c2: </font>"<font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> best_c2  <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<b>}</b>
<font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>exception<font color='#5555FF'>&amp;</font> e<font face='Lucida Console'>)</font>
<b>{</b>
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
<b>}</b>


</pre></body></html>